from datasets import Dataset, load_dataset, concatenate_datasets
from distilabel.models import vLLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration
from distilabel.steps import (KeepColumns, FormatTextGenerationSFT)
import shutil
import os
import pandas as pd

# ToMi: take the first 800 records of the "train" split built from the
# balanced test JSON file.
dataset_tomi = load_dataset(
    "json",
    data_files=".../data/test_balanced.json",
    split="train[:800]",
)

def add_combined_column_tomi(dataset):
    """Add a chat-style ``messages`` column to every ToMi example.

    Each example receives a two-turn conversation: a user turn holding the
    story, the question and the candidate containers, followed by an
    assistant turn holding the reference answer.

    Args:
        dataset: Object exposing ``map(fn)`` (a Hugging Face ``Dataset`` in
            this script). Every example must provide the ``story``,
            ``question``, ``containers`` and ``answer`` keys.

    Returns:
        The result of ``dataset.map`` applied with the per-example formatter.
    """

    def _to_messages(row):
        containers = row["containers"]
        # A list of containers is flattened to a comma-separated string;
        # any other value is interpolated into the prompt unchanged.
        options = ", ".join(containers) if isinstance(containers, list) else containers

        prompt = f"Story: {row['story']} Question: {row['question']} Choices: {options}"
        row["messages"] = [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": row["answer"]},
        ]
        return row

    return dataset.map(_to_messages)

# Attach the "messages" column to the ToMi rows, then print the dataset and
# its first row as a quick sanity check.
dataset_tomi = add_combined_column_tomi(dataset_tomi)
print(dataset_tomi)
print(dataset_tomi[0])


# Hi-ToM: six 60-row windows of the train split, loaded in order and then
# stitched together into a single dataset.
dataset1, dataset2, dataset3, dataset4, dataset5, dataset6 = (
    load_dataset(".../ToM_data/Hi-ToM", split=window)
    for window in (
        "train[:60]",
        "train[100:160]",
        "train[200:260]",
        "train[600:660]",
        "train[700:760]",
        "train[800:860]",
    )
)

dataset_hitom = concatenate_datasets(
    [dataset1, dataset2, dataset3, dataset4, dataset5, dataset6]
)

def add_combined_column_hitom(dataset):
    """Add a chat-style ``messages`` column to every Hi-ToM example.

    Each example receives a two-turn conversation: a user turn holding the
    story, the question and the answer choices, followed by an assistant
    turn holding the reference answer.

    Args:
        dataset: Object exposing ``map(fn)`` (a Hugging Face ``Dataset`` in
            this script). Every example must provide the ``story``,
            ``question``, ``choices`` and ``answer`` keys. ``choices`` may
            be a string or a list/tuple of choices.

    Returns:
        The result of ``dataset.map`` applied with the per-example formatter.
    """

    def combine_text(example):
        choices_text = example["choices"]
        # Mirror the ToMi formatter: a sequence of choices is joined into a
        # readable string instead of leaking its Python repr ("['a', 'b']")
        # into the prompt. Plain strings pass through untouched.
        if isinstance(choices_text, (list, tuple)):
            choices_text = ", ".join(str(choice) for choice in choices_text)

        example["messages"] = [
            {
                "role": "user",
                "content": (
                    f"Story: {example['story']} "
                    f"Question: {example['question']} "
                    f"Choices: {choices_text}"
                ),
            },
            {"role": "assistant", "content": example["answer"]},
        ]
        return example

    # Apply the transformation to each example.
    return dataset.map(combine_text)

# Attach the "messages" column to the Hi-ToM rows, then print the dataset and
# its first row as a quick sanity check.
dataset_hitom = add_combined_column_hitom(dataset_hitom)
print(dataset_hitom)
print(dataset_hitom[0])


# ExploreToM: first 2000 rows of the sample CSV.
dataset_exploretom = load_dataset(
    "csv",
    data_files=".../ToM_data/ExploreToM/ExploreToM-data-sample.csv",
    split="train[:2000]",
)
def add_combined_column_exploretom(dataset):
    """Add a chat-style ``messages`` column to every ExploreToM example.

    The user turn combines the ``story_structure`` and ``question`` fields
    (this source has no answer choices); the assistant turn carries the
    ``expected_answer`` field.

    Args:
        dataset: Object exposing ``map(fn)`` (a Hugging Face ``Dataset`` in
            this script). Every example must provide the
            ``story_structure``, ``question`` and ``expected_answer`` keys.

    Returns:
        The result of ``dataset.map`` applied with the per-example formatter.
    """

    def _to_messages(row):
        user_turn = {
            "role": "user",
            "content": f"Story: {row['story_structure']} Question: {row['question']}",
        }
        assistant_turn = {"role": "assistant", "content": row["expected_answer"]}
        row["messages"] = [user_turn, assistant_turn]
        return row

    return dataset.map(_to_messages)

# Attach the "messages" column to the ExploreToM rows, then print the dataset
# and its first row as a quick sanity check.
dataset_exploretom = add_combined_column_exploretom(dataset_exploretom)
print(dataset_exploretom)
print(dataset_exploretom[0])


# ToMbench: first 2420 rows of the combined training JSON.
dataset_tombench = load_dataset(
    "json",
    data_files=".../ToMbench_data/train_combined.json",
    split="train[:2420]",
)
def add_combined_column_tombench(dataset):
    """Add a chat-style ``messages`` column to every ToMbench example.

    The user turn combines the story, the question and the lettered options
    that are present; the assistant turn carries the text of the option
    selected by the answer letter.

    Args:
        dataset: Object exposing ``map(fn)`` (a Hugging Face ``Dataset`` in
            this script). Every example must provide the ``STORY``,
            ``QUESTION``, ``OPTION-A`` .. ``OPTION-D`` and ``答案\\nANSWER``
            keys. Options may be ``None`` when a question has fewer than
            four choices.

    Returns:
        The result of ``dataset.map`` applied with the per-example formatter.
    """
    answer_key = "答案\nANSWER"

    def combine_text(example):
        options = {
            "A": example["OPTION-A"],
            "B": example["OPTION-B"],
            "C": example["OPTION-C"],
            "D": example["OPTION-D"],
        }

        # Only list the options that exist; ``is not None`` (rather than
        # ``!= None``) is the correct identity test, and f-strings avoid a
        # TypeError when an option is not a plain string.
        formatted_string = " ".join(
            f"{letter}: {text}"
            for letter, text in options.items()
            if text is not None
        )

        # Normalise the answer letter so values such as "b" or "A " are not
        # silently mapped to option D and mislabelled.
        raw_letter = example[answer_key]
        letter = raw_letter.strip().upper() if isinstance(raw_letter, str) else None

        # Any unrecognised letter keeps the original fallback to option D.
        answer = options.get(letter, options["D"])

        example["messages"] = [
            {
                "role": "user",
                "content": (
                    f"Story: {example['STORY']} "
                    f"Question: {example['QUESTION']} "
                    f"Choices: {formatted_string}"
                ),
            },
            {"role": "assistant", "content": answer},
        ]
        return example

    # Apply the transformation to each example.
    return dataset.map(combine_text)

# Attach the "messages" column to the ToMbench rows, then print the dataset
# and its first row as a quick sanity check.
dataset_tombench = add_combined_column_tombench(dataset_tombench)
print(dataset_tombench)
print(dataset_tombench[0])

# SocialIQa: first 2000 rows of the training JSON, with the 'charmap' column
# dropped.
# NOTE(review): the file name is spelled "socialIWa" — confirm it matches the
# file on disk.
dataset_socialqa = load_dataset(
    "json",
    data_files=".../SocialIqa/socialIWa_v1.4_trn_wDims.json",
    split="train[:2000]",
).remove_columns(['charmap'])

def add_combined_column_socialqa(dataset):
    """Add a chat-style ``messages`` column to every SocialIQa example.

    The user turn combines the context, the question and the three lettered
    answers; the assistant turn carries the text of the answer selected by
    ``label_letter`` ("A", "B", anything else selects answer C).

    Args:
        dataset: Object exposing ``map(fn)`` (a Hugging Face ``Dataset`` in
            this script). Every example must provide the ``answerA``,
            ``answerB``, ``answerC``, ``label_letter``, ``context`` and
            ``question`` keys.

    Returns:
        The result of ``dataset.map`` applied with the per-example formatter.
    """

    def _to_messages(row):
        first, second, third = row["answerA"], row["answerB"], row["answerC"]
        options = " ".join(("A: " + first, "B: " + second, "C: " + third))

        # Resolve the gold answer text from its letter; C is the fallback.
        label = row["label_letter"]
        gold = first if label == "A" else second if label == "B" else third

        prompt = f"Context: {row['context']} Question: {row['question']} Choices: {options}"
        row["messages"] = [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": gold},
        ]
        return row

    return dataset.map(_to_messages)


# Attach the "messages" column to the SocialIQa rows, then print the dataset
# and its first row as a quick sanity check.
dataset_socialqa = add_combined_column_socialqa(dataset_socialqa)
print(dataset_socialqa)
print(dataset_socialqa[0])

# Merge every source into a single SFT training set and persist it to disk.
train_dataset = concatenate_datasets(
    [dataset_tomi, dataset_hitom, dataset_exploretom, dataset_tombench, dataset_socialqa]
)
train_dataset.save_to_disk(".../SFTData/direct_test")

# Read the saved copy back so the prints below reflect what is actually on
# disk. `load_from_disk` returns a new Dataset rather than updating the
# receiver, so its result must be assigned; previously it was discarded and
# the round-trip was never checked.
train_dataset = Dataset.load_from_disk(".../SFTData/direct_test")
print(train_dataset)
print(train_dataset[0])
